{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Converting a Huggingface model to ONNX with tf2onnx\n",
    "\n",
    "This is a simple example how to convert a [huggingface](https://huggingface.co/) model to ONNX using [tf2onnx](https://github.com/onnx/tensorflow-onnx).\n",
    "\n",
    "We use the [TFBertForQuestionAnswering](https://huggingface.co/transformers/model_doc/bert.html#tfbertforquestionanswering) example from huggingface.\n",
    "\n",
    "Other models will work similar. You'll find additional examples for other models in our unit tests [here](https://github.com/onnx/tensorflow-onnx/blob/master/tests/huggingface.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install tensorflow transformers tf2onnx onnxruntime"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The keras code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"\"\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import numpy as np\n",
    "import onnxruntime as rt\n",
    "import tensorflow as tf\n",
    "import tf2onnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    },
    {
     "data": {
      "text/plain": [
       "TFQuestionAnsweringModelOutput(loss=None, start_logits=<tf.Tensor: shape=(1, 16), dtype=float32, numpy=\n",
       "array([[ 0.27443457,  0.02250022, -0.32903647, -0.32448006, -0.26440915,\n",
       "        -0.03356116, -0.11466929, -0.12272861, -0.23254037, -0.21369037,\n",
       "         0.02170385, -0.38734213, -0.14865303, -0.04804918,  0.02706608,\n",
       "        -0.12273058]], dtype=float32)>, end_logits=<tf.Tensor: shape=(1, 16), dtype=float32, numpy=\n",
       "array([[-0.23549399,  0.11830041, -0.16875415,  0.04315909,  0.00721513,\n",
       "         0.20957005,  0.00850991, -0.49158442,  0.10791501,  0.07153591,\n",
       "         0.26274043, -0.15160318, -0.01847767,  0.03389414,  0.25666913,\n",
       "        -0.49158433]], dtype=float32)>, hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertTokenizer, TFBertForQuestionAnswering\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
    "model = TFBertForQuestionAnswering.from_pretrained('bert-base-cased')\n",
    "question, text = \"Who was Jim Henson?\", \"Jim Henson was a nice puppet\"\n",
    "input_dict = tokenizer(question, text, return_tensors='tf')\n",
    "tf_results = model(input_dict)\n",
    "tf_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": []
    }
   ],
   "source": [
    "# describe the inputs\n",
    "input_spec = (\n",
    "    tf.TensorSpec((None,  None), tf.int32, name=\"input_ids\"),\n",
    "    tf.TensorSpec((None,  None), tf.int32, name=\"token_type_ids\"),\n",
    "    tf.TensorSpec((None,  None), tf.int32, name=\"attention_mask\")\n",
    ")\n",
    "\n",
    "# and convert\n",
    "_, _ = tf2onnx.convert.from_keras(model, input_signature=input_spec, opset=13, output_path=\"bert.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the ONNX model with onnxruntime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 0.27443478,  0.02250013, -0.32903633, -0.32448038, -0.26440892,\n",
       "         -0.03356095, -0.11466938, -0.12272887, -0.2325401 , -0.21369015,\n",
       "          0.02170385, -0.3873423 , -0.148653  , -0.04804894,  0.02706566,\n",
       "         -0.1227307 ]], dtype=float32),\n",
       " array([[-0.23549382,  0.11830062, -0.16875397,  0.0431588 ,  0.00721494,\n",
       "          0.2095699 ,  0.00850987, -0.49158436,  0.10791501,  0.07153573,\n",
       "          0.26274025, -0.15160298, -0.01847767,  0.03389416,  0.25666922,\n",
       "         -0.49158415]], dtype=float32)]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get the names we want as output\n",
    "output_names = list(tf_results.keys())\n",
    "\n",
    "# switch the input_dict to numpy\n",
    "input_dict_np = {k: v.numpy() for k, v in input_dict.items()}\n",
    "\n",
    "opt = rt.SessionOptions()\n",
    "sess = rt.InferenceSession(\"bert.onnx\")\n",
    "onnx_results = sess.run(output_names, input_dict_np)\n",
    "onnx_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make sure tensorflow and onnxruntime results are the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, name in enumerate(output_names):\n",
    "    np.testing.assert_allclose(tf_results[name], onnx_results[i], rtol=1e-5, atol=1e-5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root] *",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
